/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
 * Description: graph_infershape_util.cpp
 */

#include "framework/graph/core/infershape/graph_infershape_util.h"

#include "infra/base/assertion.h"

// inc/framework
#include "framework/infra/log/log.h"
#include "framework/graph/core/cgraph/compute_graph.h"
#include "framework/graph/core/cgraph/graph_sorter.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"
#include "framework/graph/core/edge/edge.h"
#include "framework/graph/core/edge/edge_visitor.h"
#include "framework/graph/core/node/node.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_functor.h"
#include "framework/graph/core/node/node_walker.h"
#include "framework/graph/core/infershape/op_ir_facade.h"
#include "framework/graph/core/infershape/op_ir_func_factory.h"

namespace ge {
namespace {
class GraphInfershape : private NodeFunctor {
public:
    explicit GraphInfershape(InferContext& inferContext) : inferContext_(inferContext) {}
    ~GraphInfershape() override = default;

    hiai::Status InferGraphNodes(const ComputeGraph& graph)
    {
        return graph.ROLE(GraphListWalker).WalkAllNodesModifiable([](Node& /* node */) { return true; }, ToVisitor());
    }

private:
    hiai::Status Visit(Node& node) override
    {
        SetInputTensor(node);

        OpVerifyFunc verifyFunc = OpIRFuncFactory::Instance()->GetVerifyFunc(node);
        if (verifyFunc != nullptr && RunFunc(node, verifyFunc) != GRAPH_SUCCESS) {
            for (const std::string& msg : inferContext_.GetVerifyInfo().messages) {
                FMK_LOGE("[op:%s type:%s] Verify failed, %s",
                    node.ROLE(NodeSpec).Name().c_str(), node.ROLE(NodeSpec).Type().c_str(), msg.c_str());
            }
            return GRAPH_FAILED;
        }

        OpInferFunc inferFunc = OpIRFuncFactory::Instance()->GetInferFunc(node);
        HIAI_EXPECT_NOT_NULL(inferFunc);

        return RunFunc(node, inferFunc);
    }

    void SetInputTensor(Node& node)
    {
        auto visitor = [&node](Edge& edge) {
            const TensorDesc& tensor = edge.SrcNode().ROLE(NodeSpec).OpDesc().GetOutputDesc(edge.SrcIdx());
            (void)node.ROLE(NodeSpec).OpDesc().UpdateInputDesc(edge.DstIdx(), tensor);

            return hiai::SUCCESS;
        };

        (void)node.ROLE(NodeWalker).ListInDataEdges(std::move(visitor));
    }

    hiai::Status RunFunc(Node& node, OpInferFunc& func)
    {
        OpIRFacade opIRFacade(node);
        inferContext_.SetOpIRFacade(opIRFacade);

        HIAI_EXPECT_EXEC(func(inferContext_));

        return hiai::SUCCESS;
    }

private:
    InferContext& inferContext_;
};
} // namespace

hiai::Status GraphInfershapeUtil::InferShape(const ge::ComputeGraph& graph)
{
    // Convenience overload: infer with a default-constructed context, which is discarded when this call returns.
    InferContext localContext;
    hiai::Status status = InferShape(graph, localContext);
    return status;
}

hiai::Status GraphInfershapeUtil::InferShape(const ge::ComputeGraph& graph, InferContext& inferContext)
{
    // Reorder the graph first. Each node's inputs are copied from its producers' outputs, so producers are
    // presumably expected to come before consumers in the DFS order (assumed topological; confirm in GraphSorter).
    HIAI_EXPECT_EXEC(graph.ROLE(GraphSorter).SortNodesDFS());

    // A temporary functor is enough: it only borrows inferContext for the duration of this full expression.
    return GraphInfershape(inferContext).InferGraphNodes(graph);
}
} // namespace ge